import torch

##### Locally Self-Balancing Sampler
class LSBS():

    ##### Function initializing elements used 
    ##### to compute and sample from Q(X'|X)
    # batch     - number of chains
    # n         - size of square lattice
    # lamda     - scalar
    # alpha     - n x n matrix of bias coefficients
    # func      - 'lsb1', 'lsb2' 
    # lr        - learning rate
    # start_x   - batch x n x n initial tensor
    # theta     - scalar vector for combining balancing functions for 'lsb1'
    #           state dictionary of parameters of the neural net for 'lsb2'
    # gamma     - scalar used to ensure monotonicity of balancing function
    # factor    - integer indicating the number of blocks in the monotonic network
    # block     - integer indicating the size of each block in the monotonic network
    # max_value - maximum value for random numbers in the monotonicity regularizer
    # device    - device to run the simulation
    # burn_in   - Burn-in - True, Sampling - False
    def __init__(self, batch, n, lamda, alpha, func, lr=1e-3, start_x=None, theta=None, \
                gamma=.1, factor=20, block=20, max_value=2, device=None, burn_in=True):

        self.batch = batch
        self.n = n
        self.lamda = lamda
        self.device = device
        self.alpha = alpha.to(self.device)
        self.method = func
        self.lr = lr
        self.eps = 1e-3
        # Params for neural net
        self.gamma = gamma
        self.factor = factor
        self.max_value = max_value
        self.block = block
        self.hidden = self.factor * self.block

        self.burn_in = burn_in

        if self.method == 'lsb1': # Learning to select
            # Initializing theta to prefer sqrt balancing function
            self.velocity = torch.tensor(0., requires_grad=False, device=self.device)
            self.softmax = torch.nn.Softmax(dim=0)
            self.func = self.linear_normalized
            if theta is None:
                self.theta = torch.zeros((4,), requires_grad=True, device=self.device)
            else: 
                self.theta = theta.clone()
        elif self.method == 'lsb2': # Learning the balancing function
            self.init_lr = 1e-2
            self.init_iter = 5000
            self.model = MonoNet(self.hidden, self.block, device)
            self.optimizer = torch.optim.SGD(self.model.parameters(), lr=self.lr, momentum=0.9)
            self.func = self.nonlinear
            if theta is not None:
                self.model.load_state_dict(theta)
            else:
                input_data = torch.rand((self.batch, 1, 1), requires_grad=False, device=self.device) * self.max_value
                optimizer = torch.optim.SGD(self.model.parameters(), lr=self.init_lr, momentum=0.9)
                for i in range(self.init_iter):
                    optimizer.zero_grad()
                    loss = torch.mean((self.func(input_data) - self.maxim(input_data))**2)
                    if i % 100 == 0:
                        print(i, loss.item())
                    loss.backward(retain_graph=True)
                    optimizer.step()
                print('\nInitialization of lsb2 done!\n')
        else:
            raise Exception('Wrong selection of balancing function!')

        # Computing invariants
        if start_x == None:
            self.x = torch.randint(2, (batch, n, n), requires_grad=False, device=self.device)*2 - 1
        else:
            self.x = start_x.clone()

        temp_up = torch.cat((self.x[:, 1:, :], self.x[:, :1, :]), 1)       # Moving up 1 row
        temp_low = torch.cat((self.x[:, -1:,:], self.x[:, :-1, :]), 1)     # Moving down 1 row
        temp_left = torch.cat((self.x[:, :, 1:], self.x[:, :, :1]), 2)     # Moving left 1 col
        temp_right = torch.cat((self.x[:, :, -1:], self.x[:, :, :-1]), 2)  # Moving right 1 col
        self.ratio = torch.exp(-2. * self.alpha * self.x
                                -2. * self.lamda * self.x
                                * (temp_up + temp_low + temp_left + temp_right))
        # print(torch.median(self.ratio), torch.max(self.ratio))
        self.g_ratio = self.func(self.ratio)
        self.sum_rows = torch.sum(self.g_ratio, 2)
        self.sum_rows = torch.maximum(self.sum_rows, torch.tensor(0., requires_grad=False, device=self.device)) # Check to avoid numerical issues
        self.norm_const = torch.sum(self.sum_rows, 1)
        self.norm_const = torch.maximum(self.norm_const, torch.tensor(0., requires_grad=False, device=self.device)) # Check to avoid numerical issues
        self.accept = None
        self.obj = None


    ##### Function used to sample a new configuration (during burn-in)
    def sample_burn_in(self):
        
        #### Resetting at the beginning of each iteration
        self.x[(self.batch // 2 + self.batch % 2):, :, :] = torch.randint(2, (self.batch // 2, self.n, self.n), requires_grad=False, device=self.device)*2 - 1
        # Updating ratio, g_ratio, sum_rows and norm_const
        temp_up = torch.cat((self.x[:, 1:, :], self.x[:, :1, :]), 1)       # Moving up 1 row
        temp_low = torch.cat((self.x[:, -1:,:], self.x[:, :-1, :]), 1)     # Moving down 1 row
        temp_left = torch.cat((self.x[:, :, 1:], self.x[:, :, :1]), 2)     # Moving left 1 col
        temp_right = torch.cat((self.x[:, :, -1:], self.x[:, :, :-1]), 2)  # Moving right 1 col
        self.ratio = torch.exp(-2. * self.alpha * self.x                                                        \
                                -2. * self.lamda * self.x                                                       \
                                * (temp_up + temp_low + temp_left + temp_right))
        # print(torch.median(self.ratio), torch.max(self.ratio))
        self.g_ratio = self.func(self.ratio)
        self.sum_rows = torch.sum(self.g_ratio, 2)
        self.sum_rows = torch.maximum(self.sum_rows, torch.tensor(0., requires_grad=False, device=self.device)) # Check to avoid numerical issues
        self.norm_const = torch.sum(self.sum_rows, 1)
        self.norm_const = torch.maximum(self.norm_const, torch.tensor(0., requires_grad=False, device=self.device)) # Check to avoid numerical issues
        # Computing log P_tilde(X)
        self.log_p_tilde = torch.sum(self.alpha * self.x, dim=[1, 2])                                           \
                                            + 0.5 * self.lamda * torch.sum(self.x                               \
                                            * (temp_up + temp_low + temp_left + temp_right), dim=[1, 2])
        # Backup X
        self.x_temp = self.x.clone()
        # Initialize w, A, T
        self.w = torch.ones((self.batch,), device=self.device)
        self.A = torch.ones((self.batch,), device=self.device)
        self.T = torch.zeros((self.batch,), device=self.device)
        ####

        batch_id = torch.arange(self.batch, device=self.device)

        # Sampling a pair of indices
        probs = self.sum_rows / torch.sum(self.sum_rows, 1, keepdim=True)
        self.i = torch.multinomial(probs, 1)[:, 0]
        temp = torch.arange(0, self.batch, device=self.device)
        probs = self.g_ratio[temp, self.i, :] / torch.sum(self.g_ratio[temp, self.i, :], 1, keepdim=True)
        self.j = torch.multinomial(probs, 1)[:, 0]

        # Compute the neighbours
        self.i_up = self.i - 1
        self.i_up[self.i_up < 0] = self.n - 1
        self.i_low = self.i + 1
        self.i_low[self.i_low == self.n] = 0
        self.j_left = self.j - 1
        self.j_left[self.j_left < 0] = self.n - 1
        self.j_right = self.j + 1
        self.j_right[self.j_right == self.n] = 0

        # Compute new normalization constant Z(X')
        delta_i = self.func(1./self.ratio[batch_id, self.i, self.j])                                                \
                - self.g_ratio[batch_id, self.i, self.j]                                                            \
                - self.g_ratio[batch_id, self.i, self.j_left]                                                       \
                - self.g_ratio[batch_id, self.i, self.j_right]                                                      \
                + self.func(self.ratio[batch_id, self.i, self.j_left]                                               \
                            * torch.exp(4. * self.lamda * self.x[batch_id, self.i, self.j_left]                     \
                                        * self.x[batch_id, self.i, self.j]))                                        \
                + self.func(self.ratio[batch_id, self.i, self.j_right]                                              \
                            * torch.exp(4. * self.lamda * self.x[batch_id, self.i, self.j_right]                    \
                                        * self.x[batch_id, self.i, self.j]))                    
        delta_i_up = - self.g_ratio[batch_id, self.i_up, self.j]                                                    \
                    + self.func(self.ratio[batch_id, self.i_up, self.j]                                             \
                                * torch.exp(4. * self.lamda * self.x[batch_id, self.i_up, self.j]                   \
                                            * self.x[batch_id, self.i, self.j]))
        delta_i_low = - self.g_ratio[batch_id, self.i_low, self.j]                                                  \
                    + self.func(self.ratio[batch_id, self.i_low, self.j]                                            \
                                * torch.exp(4. * self.lamda * self.x[batch_id, self.i_low, self.j]                  \
                                            * self.x[batch_id, self.i, self.j]))
        norm_const_new = self.norm_const + delta_i + delta_i_up + delta_i_low

        # Compute acceptance score
        self.accept = torch.minimum(torch.tensor(1., device=self.device), self.norm_const / norm_const_new)

        # Compute proposal score
        self.proposal = self.g_ratio[batch_id, self.i, self.j] / self.norm_const

        # Compute unmnormalized true score
        self.log_p_tilde = self.log_p_tilde                                                                        \
                           - 2. * self.alpha[batch_id, self.i, self.j] * self.x[batch_id, self.i, self.j]          \
                           - 2. * self.lamda * (self.x[batch_id, self.i, self.j_left]                              \
                                                + self.x[batch_id, self.i, self.j_right]                           \
                                                + self.x[batch_id, self.i_up, self.j]                              \
                                                + self.x[batch_id, self.i_low, self.j])

        # Update w, A, T
        self.w = self.w * (self.proposal / self.proposal.detach())
        self.A = self.A * self.accept
        self.T = self.T + torch.log(self.accept) + torch.log(self.proposal) - self.log_p_tilde.detach()

        # Acceptance
        self.ratio[batch_id, self.i, self.j] = 1./self.ratio[batch_id, self.i, self.j]
        self.ratio[batch_id, self.i_up, self.j] = self.ratio[batch_id, self.i_up, self.j]                           \
                        * torch.exp(4. * self.lamda * self.x[batch_id, self.i_up, self.j]                           \
                                                    * self.x[batch_id, self.i, self.j])
        self.ratio[batch_id, self.i_low, self.j] = self.ratio[batch_id, self.i_low, self.j]                         \
                        * torch.exp(4. * self.lamda * self.x[batch_id, self.i_low, self.j]                          \
                                                    * self.x[batch_id, self.i, self.j])
        self.ratio[batch_id, self.i, self.j_left] = self.ratio[batch_id, self.i, self.j_left]                       \
                        * torch.exp(4. * self.lamda * self.x[batch_id, self.i, self.j_left]                         \
                                                    * self.x[batch_id, self.i, self.j])
        self.ratio[batch_id, self.i, self.j_right] = self.ratio[batch_id, self.i, self.j_right]                     \
                        * torch.exp(4. * self.lamda * self.x[batch_id, self.i, self.j_right]                        \
                                                    * self.x[batch_id, self.i, self.j])

        self.g_ratio[batch_id, self.i, self.j] = self.func(self.ratio[batch_id, self.i, self.j])
        self.g_ratio[batch_id, self.i_up, self.j] = self.func(self.ratio[batch_id, self.i_up, self.j])
        self.g_ratio[batch_id, self.i_low, self.j] = self.func(self.ratio[batch_id, self.i_low, self.j])
        self.g_ratio[batch_id, self.i, self.j_left] = self.func(self.ratio[batch_id, self.i, self.j_left])
        self.g_ratio[batch_id, self.i, self.j_right] = self.func(self.ratio[batch_id, self.i, self.j_right])

        self.sum_rows[batch_id, self.i] = self.sum_rows[batch_id, self.i] + delta_i
        self.sum_rows[batch_id, self.i_up] = self.sum_rows[batch_id, self.i_up] + delta_i_up
        self.sum_rows[batch_id, self.i_low] = self.sum_rows[batch_id, self.i_low] + delta_i_low
        self.sum_rows = torch.maximum(self.sum_rows, torch.tensor(0., requires_grad=False, device=self.device)) # Check to avoid numerical issues

        self.norm_const = norm_const_new
        self.norm_const = torch.maximum(self.norm_const, torch.tensor(0., requires_grad=False, device=self.device)) # Check to avoid numerical issues

        self.x[batch_id, self.i, self.j] = - self.x[batch_id, self.i, self.j]

        # Compute objective
        scalar = 2. * torch.ones((self.batch,), requires_grad=False, device=self.device) / self.batch
        scalar[(self.batch // 2 + self.batch % 2):] = scalar[(self.batch // 2 + self.batch % 2):]
        self.obj = torch.sum(self.w * self.A * self.T * scalar)
        if self.method == 'lsb2':
            epsilon = 1e-6
            ratio = torch.rand((self.batch, 1, 1, 1), requires_grad=False, device=self.device) * self.max_value
            value_1_inv = self.model(1. / ratio)
            value_2_inv = self.model(1. / (ratio + epsilon))
            regularizer = torch.maximum((ratio * value_1_inv - (ratio + epsilon) * value_2_inv) / epsilon \
                                        , torch.tensor(0., requires_grad=False, device=self.device))
            self.obj = self.obj + self.gamma * torch.mean(regularizer)
            self.optimizer.zero_grad()
        else:
            self.theta.retain_grad() # This is required to retain the gradient for theta (otherwise it is not computed, as requires_grad=True was not for theta)
        self.obj.backward(retain_graph=True) # Keep the whole graph for each call of sample_burn_in

        # Accept or reject trajectory
        id_reject = torch.rand((self.batch,), device=self.device) > self.A
        if id_reject[id_reject == True].nelement() != 0:
            self.x[batch_id[id_reject], :, :] = self.x_temp[batch_id[id_reject], :, :]

        # Update parameters using SGD with Momentum
        if self.method == 'lsb2':
            self.optimizer.step()
        else:
            with torch.no_grad():
                self.velocity = 0.9 * self.velocity + self.theta.grad
                self.theta -= self.lr * self.velocity
                self.theta.grad = None


    ##### Function used to sample a new configuration
    def sample(self):

        batch_id = torch.arange(self.batch, device=self.device)

        # Sampling a pair of indices
        probs = self.sum_rows / torch.sum(self.sum_rows, 1, keepdim=True)
        self.i = torch.multinomial(probs, 1)[:, 0]
        temp = torch.arange(0, self.batch, device=self.device)
        probs = self.g_ratio[temp, self.i, :] / torch.sum(self.g_ratio[temp, self.i, :], 1, keepdim=True)
        self.j = torch.multinomial(probs, 1)[:, 0]

        # Compute the neighbours
        self.i_up = self.i - 1
        self.i_up[self.i_up < 0] = self.n - 1
        self.i_low = self.i + 1
        self.i_low[self.i_low == self.n] = 0
        self.j_left = self.j - 1
        self.j_left[self.j_left < 0] = self.n - 1
        self.j_right = self.j + 1
        self.j_right[self.j_right == self.n] = 0

        # Compute new normalization constant Z(X')
        delta_i = self.func(1./self.ratio[batch_id, self.i, self.j])                                                \
                - self.g_ratio[batch_id, self.i, self.j]                                                            \
                - self.g_ratio[batch_id, self.i, self.j_left]                                                       \
                - self.g_ratio[batch_id, self.i, self.j_right]                                                      \
                + self.func(self.ratio[batch_id, self.i, self.j_left]                                               \
                            * torch.exp(4. * self.lamda * self.x[batch_id, self.i, self.j_left]                     \
                                        * self.x[batch_id, self.i, self.j]))                                        \
                + self.func(self.ratio[batch_id, self.i, self.j_right]                                              \
                            * torch.exp(4. * self.lamda * self.x[batch_id, self.i, self.j_right]                    \
                                        * self.x[batch_id, self.i, self.j]))                    
        delta_i_up = - self.g_ratio[batch_id, self.i_up, self.j]                                                    \
                    + self.func(self.ratio[batch_id, self.i_up, self.j]                                             \
                                * torch.exp(4. * self.lamda * self.x[batch_id, self.i_up, self.j]                   \
                                            * self.x[batch_id, self.i, self.j]))
        delta_i_low = - self.g_ratio[batch_id, self.i_low, self.j]                                                  \
                    + self.func(self.ratio[batch_id, self.i_low, self.j]                                            \
                                * torch.exp(4. * self.lamda * self.x[batch_id, self.i_low, self.j]                  \
                                            * self.x[batch_id, self.i, self.j]))
        norm_const_new = self.norm_const + delta_i + delta_i_up + delta_i_low

        self.accept = torch.minimum(torch.tensor(1., device=self.device), self.norm_const / norm_const_new)

        id_accept = torch.rand((self.batch,), device=self.device) < self.norm_const / norm_const_new

        # Acceptance
        if id_accept[id_accept == True].nelement() != 0:
            self.ratio[batch_id[id_accept], self.i[id_accept], self.j[id_accept]] =                                 \
                            1./self.ratio[batch_id[id_accept], self.i[id_accept], self.j[id_accept]]
            self.ratio[batch_id[id_accept], self.i_up[id_accept], self.j[id_accept]] =                              \
                            self.ratio[batch_id[id_accept], self.i_up[id_accept], self.j[id_accept]]                \
                            * torch.exp(4. * self.lamda                                                             \
                                        * self.x[batch_id[id_accept], self.i_up[id_accept], self.j[id_accept]]      \
                                        * self.x[batch_id[id_accept], self.i[id_accept], self.j[id_accept]])
            self.ratio[batch_id[id_accept], self.i_low[id_accept], self.j[id_accept]] =                             \
                            self.ratio[batch_id[id_accept], self.i_low[id_accept], self.j[id_accept]]               \
                            * torch.exp(4. * self.lamda                                                             \
                                        * self.x[batch_id[id_accept], self.i_low[id_accept], self.j[id_accept]]     \
                                        * self.x[batch_id[id_accept], self.i[id_accept], self.j[id_accept]])
            self.ratio[batch_id[id_accept], self.i[id_accept], self.j_left[id_accept]] =                            \
                            self.ratio[batch_id[id_accept], self.i[id_accept], self.j_left[id_accept]]              \
                            * torch.exp(4. * self.lamda                                                             \
                                        * self.x[batch_id[id_accept], self.i[id_accept], self.j_left[id_accept]]    \
                                        * self.x[batch_id[id_accept], self.i[id_accept], self.j[id_accept]])
            self.ratio[batch_id[id_accept], self.i[id_accept], self.j_right[id_accept]] =                           \
                            self.ratio[batch_id[id_accept], self.i[id_accept], self.j_right[id_accept]]             \
                            * torch.exp(4. * self.lamda                                                             \
                                        * self.x[batch_id[id_accept], self.i[id_accept], self.j_right[id_accept]]   \
                                        * self.x[batch_id[id_accept], self.i[id_accept], self.j[id_accept]])

            self.g_ratio[batch_id[id_accept], self.i[id_accept], self.j[id_accept]] =                               \
                            self.func(self.ratio[batch_id[id_accept], self.i[id_accept], self.j[id_accept]])
            self.g_ratio[batch_id[id_accept], self.i_up[id_accept], self.j[id_accept]] =                            \
                            self.func(self.ratio[batch_id[id_accept], self.i_up[id_accept], self.j[id_accept]])
            self.g_ratio[batch_id[id_accept], self.i_low[id_accept], self.j[id_accept]] =                           \
                            self.func(self.ratio[batch_id[id_accept], self.i_low[id_accept], self.j[id_accept]])
            self.g_ratio[batch_id[id_accept], self.i[id_accept], self.j_left[id_accept]] =                          \
                            self.func(self.ratio[batch_id[id_accept], self.i[id_accept], self.j_left[id_accept]])
            self.g_ratio[batch_id[id_accept], self.i[id_accept], self.j_right[id_accept]] =                         \
                            self.func(self.ratio[batch_id[id_accept], self.i[id_accept], self.j_right[id_accept]])

            self.sum_rows[batch_id[id_accept], self.i[id_accept]] =                                                 \
                            self.sum_rows[batch_id[id_accept], self.i[id_accept]] + delta_i[id_accept]
            self.sum_rows[batch_id[id_accept], self.i_up[id_accept]] =                                              \
                            self.sum_rows[batch_id[id_accept], self.i_up[id_accept]] + delta_i_up[id_accept]
            self.sum_rows[batch_id[id_accept], self.i_low[id_accept]] =                                             \
                            self.sum_rows[batch_id[id_accept], self.i_low[id_accept]] + delta_i_low[id_accept]
            self.sum_rows = torch.maximum(self.sum_rows, torch.tensor(0., device=self.device)) # Check to avoid numerical issues

            self.norm_const[id_accept] = norm_const_new[id_accept]
            self.norm_const = torch.maximum(self.norm_const, torch.tensor(0., device=self.device)) # Check to avoid numerical issues

            self.x[batch_id[id_accept], self.i[id_accept], self.j[id_accept]] =                                     \
                            - self.x[batch_id[id_accept], self.i[id_accept], self.j[id_accept]]


    ##### Balancing functions
    def barker(self, inp):
        return inp / (1 + inp)

    def sqrt(self, inp):
        return torch.sqrt(inp)

    def minim(self, inp):
        return torch.minimum(torch.tensor(1., device=self.device), inp)

    def maxim(self, inp):
        return torch.maximum(torch.tensor(1., device=self.device), inp)
    
    def linear_normalized(self, inp):
        if self.burn_in == True:
            normalized = self.softmax(self.theta + torch.randn_like(self.theta))
        else:
            normalized = self.softmax(self.theta)
        return normalized[0] * self.barker(inp) + \
               normalized[1] * self.sqrt(inp) +   \
               normalized[2] * self.minim(inp) +  \
               normalized[3] * self.maxim(inp)

    def nonlinear(self, inp):
        length = 0
        inp = inp.unsqueeze(1)
        if len(inp.shape) == 3:
            length = 3
            inp = inp.unsqueeze(1)
        elif len(inp.shape) == 2:
            length = 2
            inp = inp.unsqueeze(1).unsqueeze(1)

        value_1 = self.model(inp)
        value_2 = torch.maximum(inp, torch.tensor(0., requires_grad=False, device=self.device)) * self.model(1. / inp)

        result = torch.minimum(value_1, value_2).squeeze(1)
        
        if length == 3:
            result = result.squeeze(1)
        elif length == 2:
            result = result.squeeze(1).squeeze(1)
        return result

class MonoNet(torch.nn.Module):
    """Small scalar-to-scalar network applied independently to every element.

    Each input value is squashed, passed through ``hidden`` units with
    softplus-positive weights and a softplus activation, then reduced by a
    max over consecutive groups of ``block`` units followed by a min over the
    resulting groups.
    """

    def __init__(self, hidden, block, device):
        """Create the parameters.

        hidden - total number of hidden units
        block  - number of hidden units pooled together by the max
        device - device on which the parameters are created
        """
        super().__init__()
        self.hidden = hidden
        self.block = block
        scale = 1e-3  # small random initialization
        raw_weights = scale * torch.randn(self.hidden, 1, 1, 1, device=device)
        raw_bias = scale * torch.randn(self.hidden, device=device)
        self.weights = torch.nn.Parameter(raw_weights)
        self.bias = torch.nn.Parameter(raw_bias)
        self.softplus = torch.nn.Softplus()
        self.max = torch.nn.MaxPool1d(kernel_size=self.block, stride=self.block, padding=0, dilation=1)

    # inp is 4d tensor in NCHW format
    def forward(self, inp):
        """Evaluate the network on every element of ``inp``; the output has the shape of ``inp``."""
        b, c, h, w = inp.shape
        # Monotonic nonlinearity to map [0,infty) to [-0.5,0.5)
        squashed = 1. / (1 + 1./inp) - 0.5
        # 1x1 convolution with softplus-transformed (hence positive) weights
        positive_kernel = self.softplus(self.weights)
        activations = self.softplus(torch.nn.functional.conv2d(squashed, positive_kernel, bias=self.bias))
        # Put the hidden units on the last axis so they can be pooled
        per_site = activations.flatten(2).permute(0, 2, 1)
        block_max = self.max(per_site)
        smallest = block_max.min(dim=2, keepdim=True).values
        return smallest.permute(0, 2, 1).reshape(b, c, h, w)
